{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Manual Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seed so the random tensors drawn from here on are reproducible\n",
    "# on Restart & Run All.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Weight matrices of the bias-free toy network: 3 inputs -> 4 units -> 5 units.\n",
    "w1 = torch.randn((3, 4))\n",
    "w2 = torch.randn((4, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (1, 3) input row plus three independent (1, 4) random samples,\n",
    "# drawn in the order u0, u0_sample1, u0_sample2, u0_sample3.\n",
    "u0 = torch.randn(1, 3)\n",
    "u0_sample1, u0_sample2, u0_sample3 = (torch.randn(1, 4) for _ in range(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manual forward pass: u* are the pre-activations, a* the ReLU outputs.\n",
    "u1 = torch.matmul(u0, w1)\n",
    "a1 = torch.relu(u1)\n",
    "\n",
    "u2 = torch.matmul(a1, w2)\n",
    "a2 = torch.relu(u2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 1.5046, -0.9168, -0.1071]]),\n",
       " tensor([[ 0.7525, -2.2559,  1.2341, -1.9694]]),\n",
       " tensor([[-1.7027,  0.1315,  0.5429, -1.1076, -2.6565]]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Input row followed by the pre-activations of layer 1 and layer 2.\n",
    "(u0, u1, u2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1, 1, 0, 0],\n",
       "        [0, 0, 0, 0, 0],\n",
       "        [0, 1, 1, 0, 0],\n",
       "        [0, 0, 0, 0, 0]], dtype=torch.int32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Binarize the activations: True wherever the ReLU output is non-zero.\n",
    "bin_a1 = a1.bool()\n",
    "bin_a2 = a2.bool()\n",
    "\n",
    "# Outer product of the two 0/1 masks: entry (r, c) is 1 only when unit r of\n",
    "# layer 1 and unit c of layer 2 are both active.\n",
    "torch.ger(bin_a1.view(-1).int(), bin_a2.view(-1).int())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ True, False,  True, False]]),\n",
       " tensor([[False,  True,  True, False, False]]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Boolean activity masks of layer 1 and layer 2.\n",
    "(bin_a1, bin_a2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this cell originally read `l0 @ w0`, but neither name is defined\n",
    "# anywhere in the notebook, so it raised NameError. Presumably the first-layer\n",
    "# product was intended -- confirm.\n",
    "u0 @ w1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: iterating over an empty range never runs the loop body,\n",
    "# so nothing is printed.\n",
    "for step in range(0):\n",
    "    print('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 2]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: concatenating two one-element lists.\n",
    "[*[1], *[2]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two random 4x4 matrices: uniform samples from torch.rand shifted down by 0.5.\n",
    "l1 = torch.rand((4, 4)).sub(0.5)\n",
    "l2 = torch.rand((4, 4)).sub(0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.4650,  0.4306,  0.3862, -0.3515],\n",
       "         [ 0.3349, -0.3047,  0.4430,  0.1848],\n",
       "         [-0.2227, -0.1486,  0.4093,  0.3325],\n",
       "         [ 0.4175, -0.2552, -0.0873, -0.2341]]),\n",
       " tensor([[-0.4936,  0.3309, -0.3928, -0.4395],\n",
       "         [ 0.3285, -0.4350,  0.0719, -0.4168],\n",
       "         [ 0.3718, -0.1376, -0.1727,  0.3263],\n",
       "         [ 0.2117,  0.0024,  0.0549,  0.1020]]))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect both random matrices.\n",
    "(l1, l2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.4306, 0.3862, 0.0000],\n",
       "        [0.3349, 0.0000, 0.4430, 0.1848],\n",
       "        [0.0000, 0.0000, 0.4093, 0.3325],\n",
       "        [0.4175, 0.0000, 0.0000, 0.0000]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ReLU replaces the negative entries of l1 with zero.\n",
    "torch.relu(l1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.9678,  0.5911,  1.7775, -1.3768]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random (1, 4) input row; the trailing bare name displays it.\n",
    "i = torch.randn(1, 4)\n",
    "i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.2227,  0.3239,  1.4834,  0.6823]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Multiply the input row by l1; no activation is applied here.\n",
    "h1 = torch.matmul(i, l1)\n",
    "h1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.4060, -0.7480,  0.2848,  0.9559]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Multiply h1 by l2, again without an activation in between.\n",
    "h2 = torch.matmul(h1, l2)\n",
    "h2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nupic.research.frameworks.dynamic_sparse import networks\n",
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# net = networks.MLP()\n",
    "# net = networks.MLPHeb(config=dict(\n",
    "#   input_size=784,\n",
    "#   num_classes=10,\n",
    "#   hidden_sizes=[100,100,100],\n",
    "#   use_kwinners=False,\n",
    "# ))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small MLPHeb with the same layer widths as the manual example above\n",
    "# (3 inputs, hidden layers of 4 and 5 units) plus a 2-class output.\n",
    "net = networks.MLPHeb(\n",
    "    config={\n",
    "        'input_size': 3,\n",
    "        'num_classes': 2,\n",
    "        'hidden_sizes': [4, 5],\n",
    "        'use_kwinners': False,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Linear-1                    [-1, 4]              16\n",
      "              ReLU-2                    [-1, 4]               0\n",
      "            Linear-3                    [-1, 5]              25\n",
      "              ReLU-4                    [-1, 5]               0\n",
      "            Linear-5                    [-1, 2]              12\n",
      "================================================================\n",
      "Total params: 53\n",
      "Trainable params: 53\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.00\n",
      "Forward/backward pass size (MB): 0.00\n",
      "Params size (MB): 0.00\n",
      "Estimated Total Size (MB): 0.00\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Print the torchsummary report of the network for a (1, 3) input.\n",
    "summary(net, input_size=(1, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 3])\n",
      "torch.Size([5, 4])\n",
      "torch.Size([2, 5])\n"
     ]
    }
   ],
   "source": [
    "# Print the weight shape of every nn.Linear module in the network,\n",
    "# skipping all other module types.\n",
    "for m in net.modules():\n",
    "    if not isinstance(m, nn.Linear):\n",
    "        continue\n",
    "    print(m.weight.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the last nn.Linear layer explicitly instead of relying on a loop\n",
    "# variable left over from another cell, so this cell does not depend on\n",
    "# hidden kernel state.\n",
    "last_linear = [mod for mod in net.modules() if isinstance(mod, nn.Linear)][-1]\n",
    "\n",
    "# Transposed weight matrix of that layer.\n",
    "m2 = last_linear.weight.data.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 1, 0, 1, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 1, 0, 1, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Flattened 0/1 mask of the strictly positive entries of m2. reshape(-1) is used\n",
    "# instead of view(-1) because view raises on layouts it cannot express without a\n",
    "# copy, whereas reshape copies when needed.\n",
    "positive_mask = (m2 > 0).int().reshape(-1)\n",
    "\n",
    "# Outer product of the mask with itself: entry (r, c) is 1 only when both\n",
    "# flattened weights r and c are positive.\n",
    "torch.ger(positive_mask, positive_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.3507, -0.4252],\n",
       "        [-0.3169,  0.2901],\n",
       "        [-0.2204,  0.2198],\n",
       "        [-0.1902, -0.0727],\n",
       "        [-0.2340, -0.1836]])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the transposed weight matrix stored in m2.\n",
    "(m2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
